package com.hzhh123.crypto.advice;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.hzhh123.crypto.annotation.EncryptDecrypt;
import com.hzhh123.crypto.annotation.decrypt.*;
import com.hzhh123.crypto.entity.RequestData;
import com.hzhh123.crypto.enums.EncrptyTypeEnum;
import com.hzhh123.crypto.exception.DecryptException;
import com.hzhh123.crypto.exception.ParamException;
import com.hzhh123.crypto.handle.CryptoContext;
import lombok.RequiredArgsConstructor;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.MethodParameter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpInputMessage;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.servlet.mvc.method.annotation.RequestBodyAdvice;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.annotation.Annotation;
import java.lang.reflect.Type;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

/**
 * @author hzhh123
 * @version 1.0
 * @date 2022/4/13 9:21
 * @des
 */
@RequiredArgsConstructor(onConstructor_ = @Autowired)
@ControllerAdvice
public class DecryptRequestBodyAdvice implements RequestBodyAdvice {
    private final ObjectMapper objectMapper;
    private final CryptoContext cryptoContext;


    @Override
    public boolean supports(MethodParameter methodParameter, Type targetType,
                            Class<? extends HttpMessageConverter<?>> converterType) {

        // 方法上加密注解
        Annotation[] methodAnnotations = methodParameter.getMethodAnnotations();
        for (Annotation methodAnnotation : methodAnnotations) {
            if (methodAnnotation instanceof Decrypt
                    || methodAnnotation instanceof RsaDecrypt
                    || methodAnnotation instanceof Base64Decrypt
                    || methodAnnotation instanceof AesDecrypt
                    || methodAnnotation instanceof DesDecrypt
                    || methodAnnotation instanceof EncryptDecrypt) {
                return true;
            }

        }
        // 类上有加密注解
        Annotation[] annotations = methodParameter.getDeclaringClass().getAnnotations();
        for (Annotation annotation : annotations) {
            if (annotation instanceof Decrypt
                    || annotation instanceof RsaDecrypt
                    || annotation instanceof Base64Decrypt
                    || annotation instanceof AesDecrypt
                    || annotation instanceof DesDecrypt
                    || annotation instanceof EncryptDecrypt) {
                return true;
            }

        }
        return false;
    }

    @Override
    public HttpInputMessage beforeBodyRead(HttpInputMessage inputMessage, MethodParameter parameter, Type targetType,
                                           Class<? extends HttpMessageConverter<?>> converterType) throws IOException {
        RequestData requestData = null;
        try {
            String body = IOUtils.toString(inputMessage.getBody());
            requestData = objectMapper.readValue(body, RequestData.class);
            if (requestData == null || StringUtils.isBlank(requestData.getData())) {
                throw new ParamException("参数错误");
            }

        } catch (Exception e) {
            e.printStackTrace();
            throw new ParamException("Decrypting data failed. (解密数据失败)");
        }
        String cryptoType = getCryptoType(parameter);
        if (cryptoType == null) {
            throw new DecryptException("Decryption class not registered. (未注册解密方法类)");
        }
        String decryptStr = cryptoContext.getInstance(cryptoType).decrypt(requestData.getData());
        if (decryptStr == null) {
            throw new ParamException("Decrypting data failed. (解密数据失败)");
        }
        // 获取结果
        return new HttpInputMessage() {

            @Override
            public HttpHeaders getHeaders() {
                return inputMessage.getHeaders();
            }

            @Override
            public InputStream getBody() throws IOException {
                return IOUtils.toInputStream(decryptStr);
            }

        };
    }

    @Override
    public Object afterBodyRead(Object body, HttpInputMessage inputMessage, MethodParameter parameter, Type targetType,
                                Class<? extends HttpMessageConverter<?>> converterType) {
        return body;
    }

    @Override
    public Object handleEmptyBody(Object body, HttpInputMessage inputMessage, MethodParameter parameter, Type targetType,
                                  Class<? extends HttpMessageConverter<?>> converterType) {
        return body;
    }


    private String getCryptoType(MethodParameter returnType) {
        Annotation[] annotations = returnType.getMethodAnnotations();
        String cryptoType = getCryptoType(annotations);
        if (cryptoType == null) {
            annotations = returnType.getDeclaringClass().getAnnotations();
            cryptoType = getCryptoType(annotations);
        }
        return cryptoType;
    }

    private String getCryptoType(Annotation[] annotations) {
        for (Annotation annotation : annotations) {
            if (annotation instanceof Decrypt) {
                return StringUtils.isNotBlank(((Decrypt) annotation).value()) ? ((Decrypt) annotation).value() : ((Decrypt) annotation).type().name();
            }
            if (annotation instanceof AesDecrypt) {
                return EncrptyTypeEnum.AES.name();
            }
            if (annotation instanceof DesDecrypt) {
                return EncrptyTypeEnum.DES.name();
            }
            if (annotation instanceof RsaDecrypt) {
                return EncrptyTypeEnum.RSA.name();
            }
            if (annotation instanceof Base64Decrypt) {
                return EncrptyTypeEnum.BASE64.name();
            }
            if(annotation instanceof EncryptDecrypt){
                EncryptDecrypt encryptDecrypt=(EncryptDecrypt)annotation;
                return encryptDecrypt.value();
            }
        }
        return null;
    }
}
